{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
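    "## CART Algorithm\n",
    "\n",
    "CART stands for Classification And Regression Tree.\n",
    "ID3 and C4.5 can generate binary or multiway trees, whereas CART only generates binary trees. CART decision trees are also special in that they can serve either as classification trees or as regression trees."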
   ]
  },
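  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To illustrate both points, the sketch below fits scikit-learn's `DecisionTreeClassifier` and `DecisionTreeRegressor` on made-up toy data. The scikit-learn documentation describes its trees as an optimized version of CART. The cell then checks that each fitted tree is binary: a tree in which every internal node has exactly two children satisfies `node_count == 2 * n_leaves - 1`. The toy data and the names `cls_tree` and `reg_tree` are assumptions chosen only for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor\n",
    "\n",
    "# Hypothetical toy data: one numeric feature, a 3-class label and a continuous target\n",
    "rng = np.random.RandomState(0)\n",
    "X = rng.uniform(0, 10, size=(60, 1))\n",
    "y_class = (X[:, 0] // 3.5).astype(int)          # classes 0, 1, 2\n",
    "y_reg = np.sin(X[:, 0]) + 0.1 * rng.randn(60)   # noisy continuous target\n",
    "\n",
    "# The same tree-growing procedure used as a classification tree and as a regression tree\n",
    "cls_tree = DecisionTreeClassifier(criterion='gini', random_state=0).fit(X, y_class)\n",
    "reg_tree = DecisionTreeRegressor(random_state=0).fit(X, y_reg)\n",
    "\n",
    "# If every internal node has exactly two children, node_count == 2 * n_leaves - 1\n",
    "for name, model in [('classifier', cls_tree), ('regressor', reg_tree)]:\n",
    "    n_nodes = model.tree_.node_count\n",
    "    n_leaves = model.get_n_leaves()\n",
    "    print(name, n_nodes, n_leaves, n_nodes == 2 * n_leaves - 1)\n",
    "\n",
    "# The classifier returns a class label; the regressor returns a continuous value\n",
    "print(cls_tree.predict([[5.0]]), reg_tree.predict([[5.0]]))"
   ]
  },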
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
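    "Let $t$ be a node. The Gini index of node $t$ is computed as\n",
    "\n",
    "$$GINI(t) = 1 - \\sum_{k} [p(C_k \\mid t)]^2$$\n",
    "\n",
    "where $p(C_k \\mid t)$ is the probability that a sample at node $t$ belongs to class $C_k$. In other words, the Gini index of node $t$ is 1 minus the sum of the squared class probabilities: it is 0 for a pure node and grows as the classes at the node become more mixed."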
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CART classification tree accuracy 0.9600\n"
     ]
    },
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\r\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\r\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\r\n",
       "<!-- Generated by graphviz version 2.38.0 (20140413.2041)\r\n",
       " -->\r\n",
       "<!-- Title: Tree Pages: 1 -->\r\n",
       "<svg width=\"499pt\" height=\"477pt\"\r\n",
       " viewBox=\"0.00 0.00 499.00 477.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\r\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 473)\">\r\n",
       "<title>Tree</title>\r\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-473 495,-473 495,4 -4,4\"/>\r\n",
       "<!-- 0 -->\r\n",
       "<g id=\"node1\" class=\"node\"><title>0</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"240.5,-469 115.5,-469 115.5,-401 240.5,-401 240.5,-469\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"178\" y=\"-453.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">X[3] &lt;= 0.75</text>\r\n",
       "<text text-anchor=\"middle\" x=\"178\" y=\"-438.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.666</text>\r\n",
       "<text text-anchor=\"middle\" x=\"178\" y=\"-423.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 100</text>\r\n",
       "<text text-anchor=\"middle\" x=\"178\" y=\"-408.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [34, 31, 35]</text>\r\n",
       "</g>\r\n",
       "<!-- 1 -->\r\n",
       "<g id=\"node2\" class=\"node\"><title>1</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"167,-357.5 55,-357.5 55,-304.5 167,-304.5 167,-357.5\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"111\" y=\"-342.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.0</text>\r\n",
       "<text text-anchor=\"middle\" x=\"111\" y=\"-327.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 34</text>\r\n",
       "<text text-anchor=\"middle\" x=\"111\" y=\"-312.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [34, 0, 0]</text>\r\n",
       "</g>\r\n",
       "<!-- 0&#45;&gt;1 -->\r\n",
       "<g id=\"edge1\" class=\"edge\"><title>0&#45;&gt;1</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M156.247,-400.884C148.951,-389.776 140.803,-377.372 133.454,-366.184\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"136.179,-363.957 127.763,-357.52 130.328,-367.8 136.179,-363.957\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"122.693\" y=\"-378.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">True</text>\r\n",
       "</g>\r\n",
       "<!-- 2 -->\r\n",
       "<g id=\"node3\" class=\"node\"><title>2</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"304.5,-365 185.5,-365 185.5,-297 304.5,-297 304.5,-365\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"245\" y=\"-349.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">X[2] &lt;= 4.95</text>\r\n",
       "<text text-anchor=\"middle\" x=\"245\" y=\"-334.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.498</text>\r\n",
       "<text text-anchor=\"middle\" x=\"245\" y=\"-319.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 66</text>\r\n",
       "<text text-anchor=\"middle\" x=\"245\" y=\"-304.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [0, 31, 35]</text>\r\n",
       "</g>\r\n",
       "<!-- 0&#45;&gt;2 -->\r\n",
       "<g id=\"edge2\" class=\"edge\"><title>0&#45;&gt;2</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M199.753,-400.884C205.428,-392.243 211.619,-382.819 217.548,-373.793\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"220.562,-375.579 223.127,-365.299 214.711,-371.736 220.562,-375.579\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"228.197\" y=\"-386.08\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">False</text>\r\n",
       "</g>\r\n",
       "<!-- 3 -->\r\n",
       "<g id=\"node4\" class=\"node\"><title>3</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"236,-261 124,-261 124,-193 236,-193 236,-261\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"180\" y=\"-245.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">X[3] &lt;= 1.65</text>\r\n",
       "<text text-anchor=\"middle\" x=\"180\" y=\"-230.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.165</text>\r\n",
       "<text text-anchor=\"middle\" x=\"180\" y=\"-215.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 33</text>\r\n",
       "<text text-anchor=\"middle\" x=\"180\" y=\"-200.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [0, 30, 3]</text>\r\n",
       "</g>\r\n",
       "<!-- 2&#45;&gt;3 -->\r\n",
       "<g id=\"edge3\" class=\"edge\"><title>2&#45;&gt;3</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M223.896,-296.884C218.39,-288.243 212.385,-278.819 206.633,-269.793\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"209.546,-267.852 201.22,-261.299 203.643,-271.614 209.546,-267.852\"/>\r\n",
       "</g>\r\n",
       "<!-- 8 -->\r\n",
       "<g id=\"node9\" class=\"node\"><title>8</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"366,-261 254,-261 254,-193 366,-193 366,-261\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"310\" y=\"-245.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">X[2] &lt;= 5.05</text>\r\n",
       "<text text-anchor=\"middle\" x=\"310\" y=\"-230.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.059</text>\r\n",
       "<text text-anchor=\"middle\" x=\"310\" y=\"-215.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 33</text>\r\n",
       "<text text-anchor=\"middle\" x=\"310\" y=\"-200.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [0, 1, 32]</text>\r\n",
       "</g>\r\n",
       "<!-- 2&#45;&gt;8 -->\r\n",
       "<g id=\"edge8\" class=\"edge\"><title>2&#45;&gt;8</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M266.104,-296.884C271.61,-288.243 277.615,-278.819 283.367,-269.793\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"286.357,-271.614 288.78,-261.299 280.454,-267.852 286.357,-271.614\"/>\r\n",
       "</g>\r\n",
       "<!-- 4 -->\r\n",
       "<g id=\"node5\" class=\"node\"><title>4</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"112,-149.5 7.10543e-015,-149.5 7.10543e-015,-96.5 112,-96.5 112,-149.5\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"56\" y=\"-134.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.0</text>\r\n",
       "<text text-anchor=\"middle\" x=\"56\" y=\"-119.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 29</text>\r\n",
       "<text text-anchor=\"middle\" x=\"56\" y=\"-104.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [0, 29, 0]</text>\r\n",
       "</g>\r\n",
       "<!-- 3&#45;&gt;4 -->\r\n",
       "<g id=\"edge4\" class=\"edge\"><title>3&#45;&gt;4</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M139.741,-192.884C125.301,-181.006 109.06,-167.646 94.7512,-155.876\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"96.9706,-153.17 87.0242,-149.52 92.5237,-158.576 96.9706,-153.17\"/>\r\n",
       "</g>\r\n",
       "<!-- 5 -->\r\n",
       "<g id=\"node6\" class=\"node\"><title>5</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"235.5,-157 130.5,-157 130.5,-89 235.5,-89 235.5,-157\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"183\" y=\"-141.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">X[1] &lt;= 3.1</text>\r\n",
       "<text text-anchor=\"middle\" x=\"183\" y=\"-126.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.375</text>\r\n",
       "<text text-anchor=\"middle\" x=\"183\" y=\"-111.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 4</text>\r\n",
       "<text text-anchor=\"middle\" x=\"183\" y=\"-96.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [0, 1, 3]</text>\r\n",
       "</g>\r\n",
       "<!-- 3&#45;&gt;5 -->\r\n",
       "<g id=\"edge5\" class=\"edge\"><title>3&#45;&gt;5</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M180.974,-192.884C181.212,-184.778 181.471,-175.982 181.721,-167.472\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"185.225,-167.398 182.021,-157.299 178.228,-167.192 185.225,-167.398\"/>\r\n",
       "</g>\r\n",
       "<!-- 6 -->\r\n",
       "<g id=\"node7\" class=\"node\"><title>6</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"112.5,-53 7.5,-53 7.5,-0 112.5,-0 112.5,-53\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"60\" y=\"-37.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.0</text>\r\n",
       "<text text-anchor=\"middle\" x=\"60\" y=\"-22.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 3</text>\r\n",
       "<text text-anchor=\"middle\" x=\"60\" y=\"-7.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [0, 0, 3]</text>\r\n",
       "</g>\r\n",
       "<!-- 5&#45;&gt;6 -->\r\n",
       "<g id=\"edge6\" class=\"edge\"><title>5&#45;&gt;6</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M139.981,-88.9485C127.478,-79.3431 113.853,-68.8747 101.449,-59.345\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"103.561,-56.5537 93.4986,-53.2367 99.2961,-62.1046 103.561,-56.5537\"/>\r\n",
       "</g>\r\n",
       "<!-- 7 -->\r\n",
       "<g id=\"node8\" class=\"node\"><title>7</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"235.5,-53 130.5,-53 130.5,-0 235.5,-0 235.5,-53\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"183\" y=\"-37.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.0</text>\r\n",
       "<text text-anchor=\"middle\" x=\"183\" y=\"-22.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 1</text>\r\n",
       "<text text-anchor=\"middle\" x=\"183\" y=\"-7.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [0, 1, 0]</text>\r\n",
       "</g>\r\n",
       "<!-- 5&#45;&gt;7 -->\r\n",
       "<g id=\"edge7\" class=\"edge\"><title>5&#45;&gt;7</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M183,-88.9485C183,-80.7153 183,-71.848 183,-63.4814\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"186.5,-63.2367 183,-53.2367 179.5,-63.2367 186.5,-63.2367\"/>\r\n",
       "</g>\r\n",
       "<!-- 9 -->\r\n",
       "<g id=\"node10\" class=\"node\"><title>9</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"360.5,-157 255.5,-157 255.5,-89 360.5,-89 360.5,-157\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"308\" y=\"-141.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">X[3] &lt;= 1.8</text>\r\n",
       "<text text-anchor=\"middle\" x=\"308\" y=\"-126.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.444</text>\r\n",
       "<text text-anchor=\"middle\" x=\"308\" y=\"-111.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 3</text>\r\n",
       "<text text-anchor=\"middle\" x=\"308\" y=\"-96.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [0, 1, 2]</text>\r\n",
       "</g>\r\n",
       "<!-- 8&#45;&gt;9 -->\r\n",
       "<g id=\"edge9\" class=\"edge\"><title>8&#45;&gt;9</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M309.351,-192.884C309.192,-184.778 309.019,-175.982 308.852,-167.472\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"312.348,-167.229 308.653,-157.299 305.35,-167.366 312.348,-167.229\"/>\r\n",
       "</g>\r\n",
       "<!-- 12 -->\r\n",
       "<g id=\"node13\" class=\"node\"><title>12</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"491,-149.5 379,-149.5 379,-96.5 491,-96.5 491,-149.5\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"435\" y=\"-134.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.0</text>\r\n",
       "<text text-anchor=\"middle\" x=\"435\" y=\"-119.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 30</text>\r\n",
       "<text text-anchor=\"middle\" x=\"435\" y=\"-104.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [0, 0, 30]</text>\r\n",
       "</g>\r\n",
       "<!-- 8&#45;&gt;12 -->\r\n",
       "<g id=\"edge12\" class=\"edge\"><title>8&#45;&gt;12</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M350.584,-192.884C365.14,-181.006 381.512,-167.646 395.936,-155.876\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"398.191,-158.554 403.726,-149.52 393.765,-153.13 398.191,-158.554\"/>\r\n",
       "</g>\r\n",
       "<!-- 10 -->\r\n",
       "<g id=\"node11\" class=\"node\"><title>10</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"359.5,-53 254.5,-53 254.5,-0 359.5,-0 359.5,-53\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"307\" y=\"-37.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.0</text>\r\n",
       "<text text-anchor=\"middle\" x=\"307\" y=\"-22.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 1</text>\r\n",
       "<text text-anchor=\"middle\" x=\"307\" y=\"-7.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [0, 1, 0]</text>\r\n",
       "</g>\r\n",
       "<!-- 9&#45;&gt;10 -->\r\n",
       "<g id=\"edge10\" class=\"edge\"><title>9&#45;&gt;10</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M307.65,-88.9485C307.563,-80.7153 307.469,-71.848 307.381,-63.4814\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"310.878,-63.1991 307.272,-53.2367 303.878,-63.2732 310.878,-63.1991\"/>\r\n",
       "</g>\r\n",
       "<!-- 11 -->\r\n",
       "<g id=\"node12\" class=\"node\"><title>11</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"482.5,-53 377.5,-53 377.5,-0 482.5,-0 482.5,-53\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"430\" y=\"-37.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.0</text>\r\n",
       "<text text-anchor=\"middle\" x=\"430\" y=\"-22.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 2</text>\r\n",
       "<text text-anchor=\"middle\" x=\"430\" y=\"-7.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [0, 0, 2]</text>\r\n",
       "</g>\r\n",
       "<!-- 9&#45;&gt;11 -->\r\n",
       "<g id=\"edge11\" class=\"edge\"><title>9&#45;&gt;11</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M350.67,-88.9485C362.952,-79.4346 376.328,-69.074 388.536,-59.6175\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"391.011,-62.1274 396.774,-53.2367 386.725,-56.5934 391.011,-62.1274\"/>\r\n",
       "</g>\r\n",
       "</g>\r\n",
       "</svg>\r\n"
      ],
      "text/plain": [
       "<graphviz.files.Source at 0x27c5b1a0988>"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# encoding=utf-8\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn import tree\n",
    "from sklearn.datasets import load_iris\n",
    "import os\n",
    "import graphviz\n",
    "# Machine-specific: put the Graphviz binaries on PATH (adjust or remove for your installation)\n",
    "os.environ[\"PATH\"] += os.pathsep + 'D:/Apps/Anaconda3/Library/bin/graphviz'\n",
    "# Prepare the dataset\n",
    "iris = load_iris()\n",
    "# Get the feature matrix and the class labels\n",
    "features = iris.data\n",
    "labels = iris.target\n",
    "# Randomly hold out 33% of the data as the test set; the rest is the training set\n",
    "train_features, test_features, train_labels, test_labels = train_test_split(features, labels, test_size=0.33, random_state=0)\n",
    "# Create a CART classification tree that uses the Gini index as the split criterion\n",
    "clf = tree.DecisionTreeClassifier(criterion='gini')\n",
    "# Fit (grow) the CART classification tree on the training set\n",
    "clf = clf.fit(train_features, train_labels)\n",
    "# Predict the classes of the test set with the fitted tree\n",
    "test_predict = clf.predict(test_features)\n",
    "# Compare the predictions with the true test labels\n",
    "score = accuracy_score(test_labels, test_predict)\n",
    "print(\"CART classification tree accuracy %.4f\" % score)\n",
    "\n",
    "# Export the fitted tree in DOT format and render it with graphviz\n",
    "dot_data = tree.export_graphviz(clf, out_file=None)\n",
    "graph = graphviz.Source(dot_data)\n",
    "graph"
   ]
  },
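  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "CART grows the tree by repeatedly choosing the binary split that minimizes the weighted Gini index of the two children, $\\frac{N_L}{N} GINI(t_L) + \\frac{N_R}{N} GINI(t_R)$, where $N_L$ and $N_R$ are the numbers of samples sent to the left child $t_L$ and the right child $t_R$. The sketch below is a simplified, from-scratch illustration for a single numeric feature, not scikit-learn's actual implementation. The hypothetical `best_split` helper tries the midpoint between each pair of adjacent distinct sorted values as a threshold and keeps the one with the lowest weighted Gini. The toy arrays are made up. The last lines apply the same weighted Gini to the root split in the plot above (left value = [34, 0, 0], right value = [0, 31, 35])."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def gini(y):\n",
    "    # Gini index of a label array: 1 - sum_k p_k^2\n",
    "    if len(y) == 0:\n",
    "        return 0.0\n",
    "    _, counts = np.unique(y, return_counts=True)\n",
    "    p = counts / len(y)\n",
    "    return 1.0 - np.sum(p ** 2)\n",
    "\n",
    "def best_split(x, y):\n",
    "    # Hypothetical helper: best threshold on one numeric feature x,\n",
    "    # returned as (threshold, weighted Gini of the two children)\n",
    "    order = np.argsort(x)\n",
    "    x, y = x[order], y[order]\n",
    "    n = len(y)\n",
    "    best_thr, best_score = None, np.inf\n",
    "    for i in range(1, n):\n",
    "        if x[i] == x[i - 1]:\n",
    "            continue  # equal values cannot be separated by a threshold\n",
    "        thr = (x[i - 1] + x[i]) / 2.0  # midpoint between adjacent sorted values\n",
    "        left, right = y[:i], y[i:]  # left child gets the samples with x <= thr\n",
    "        score = (len(left) * gini(left) + len(right) * gini(right)) / n\n",
    "        if score < best_score:\n",
    "            best_thr, best_score = thr, score\n",
    "    return best_thr, best_score\n",
    "\n",
    "# Made-up toy data: one feature, three classes\n",
    "x = np.array([1.0, 2.0, 3.0, 6.0, 7.0, 8.0, 9.0])\n",
    "y = np.array([0, 0, 0, 1, 1, 2, 2])\n",
    "thr, score = best_split(x, y)\n",
    "print(thr, round(score, 4))  # 4.5 0.2857\n",
    "\n",
    "# Root split from the plot above: 34 samples go left, 66 go right\n",
    "root_left = np.repeat([0], 34)\n",
    "root_right = np.repeat([1, 2], [31, 35])\n",
    "print(round((34 * gini(root_left) + 66 * gini(root_right)) / 100, 3))  # 0.329"
   ]
  },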
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
